#!/usr/bin/python

###     ###
# Imports #
###     ###

from io import BytesIO
import sys, os, time, re, json, datetime, ctypes, subprocess, plistlib, struct, base64, binascii

if os.name == "nt":
    # Windows
    import msvcrt
else:
    # Not Windows \o/
    import select

# plistlib.FMT_XML only exists on Python 3.4+; fall back to None on older
# interpreters, where the format argument is ignored anyway.
FMT_XML = getattr(plistlib, "FMT_XML", None)

###            ###
# Helper Methods #
###            ###

def _check_py3():
    return True if sys.version_info >= (3, 0) else False

def _is_binary(fp):
    header = fp.read(32)
    fp.seek(0)
    return header[:8] == b'bplist00'

def _get_inst():
    """Return the type(s) that count as a text string on this interpreter.

    The result is suitable as the second argument to isinstance().
    """
    if not _check_py3():
        # Python 2 has two text types
        return (str, unicode)
    return str

###                             ###
# Deprecated Functions - Remapped #
###                             ###

def readPlist(pathOrFile):
    """Load a plist from a filesystem path or an already-open binary file.

    A string argument is treated as a path and opened in binary mode;
    anything else is handed to load() as a file object.
    """
    if isinstance(pathOrFile, _get_inst()):
        with open(pathOrFile, "rb") as handle:
            return load(handle)
    return load(pathOrFile)

def writePlist(value, pathOrFile):
    """Write *value* as an XML plist to a path or an open binary file.

    A string argument is treated as a path and opened for binary writing;
    anything else is handed to dump() as a file object.
    """
    if isinstance(pathOrFile, _get_inst()):
        with open(pathOrFile, "wb") as handle:
            return dump(value, handle, fmt=FMT_XML, sort_keys=True, skipkeys=False)
    return dump(value, pathOrFile, fmt=FMT_XML, sort_keys=True, skipkeys=False)

###                ###
# Remapped Functions #
###                ###

def load(fp, fmt=None, use_builtin_types=True, dict_type=dict):
    """Parse a plist from the binary file object *fp* and return its root.

    On Python 3 this delegates to plistlib.load().  The use_builtin_types
    keyword was removed from plistlib in Python 3.9 (bytes are always
    returned for <data> there), so it is only forwarded on older versions
    and is otherwise ignored.

    On Python 2 the stream is sniffed: XML plists go through
    plistlib.readPlist() and binary plists through readBinaryPlistFile().
    fmt, use_builtin_types and dict_type are not used on that path.
    """
    if sys.version_info >= (3, 0):
        kwargs = {"fmt": fmt, "dict_type": dict_type}
        if sys.version_info < (3, 9):
            # Passing this on 3.9+ raises TypeError (unexpected keyword)
            kwargs["use_builtin_types"] = use_builtin_types
        return plistlib.load(fp, **kwargs)
    if not _is_binary(fp):
        return plistlib.readPlist(fp)
    return readBinaryPlistFile(fp)

def loads(value, fmt=None, use_builtin_types=True, dict_type=dict):
    """Parse a plist held in memory and return its root object.

    *value* may be bytes or a text string; text is encoded to UTF-8 first.

    On Python 3 this delegates to plistlib.load().  The use_builtin_types
    keyword was removed from plistlib in Python 3.9, so it is only
    forwarded on older versions and is otherwise ignored.

    On Python 2, XML plists go through plistlib.readPlistFromString() and
    binary plists through readBinaryPlistFile().
    """
    if isinstance(value, _get_inst()):
        # We were sent a string - encode it to utf-8 bytes
        value = value.encode("utf-8")
    fp = BytesIO(value)
    if sys.version_info >= (3, 0):
        kwargs = {"fmt": fmt, "dict_type": dict_type}
        if sys.version_info < (3, 9):
            # Passing this on 3.9+ raises TypeError (unexpected keyword)
            kwargs["use_builtin_types"] = use_builtin_types
        return plistlib.load(fp, **kwargs)
    if not _is_binary(fp):
        return plistlib.readPlistFromString(value)
    return readBinaryPlistFile(fp)

def dump(value, fp, fmt=FMT_XML, sort_keys=True, skipkeys=False):
    """Serialize *value* as a plist into the binary file object *fp*.

    Python 3 forwards every option to plistlib.dump(); Python 2 falls back
    to plistlib.writePlist(), which takes no options.
    """
    if not _check_py3():
        plistlib.writePlist(value, fp)
        return
    plistlib.dump(value, fp, fmt=fmt, sort_keys=sort_keys, skipkeys=skipkeys)
    
def dumps(value, fmt=FMT_XML, skipkeys=False):
    """Serialize *value* as a plist and return the result.

    On Python 3 the bytes from plistlib.dumps() are decoded as UTF-8; on
    Python 2 the output of plistlib.writePlistToString() is encoded as UTF-8.
    """
    if not _check_py3():
        return plistlib.writePlistToString(value).encode("utf-8")
    serialized = plistlib.dumps(value, fmt=fmt, skipkeys=skipkeys)
    return serialized.decode("utf-8")


###                        ###
# Binary Plist Stuff For Py2 #
###                        ###

# Timestamp 0 of binary plists corresponds to 1/1/2001 (year of Mac OS X 10.0), instead of 1/1/1970.
# The offset is expressed in seconds: 31 years of 365 days plus 8 extra days, times 86400 seconds per day.
# It is added to a plist date value before converting it with utcfromtimestamp().
MAC_OS_X_TIME_OFFSET = (31 * 365 + 8) * 86400

class InvalidFileException(ValueError):
    """Signals that a binary plist could not be parsed."""

    def __str__(self):
        # Fixed message regardless of constructor arguments
        return "Invalid file"

    # Python 2 text conversion yields the same fixed message
    __unicode__ = __str__

def readBinaryPlistFile(in_file):
    """
    Read a binary plist file, following the description of the binary format: http://opensource.apple.com/source/CF/CF-550/CFBinaryPList.c
    Raise InvalidFileException in case of error, otherwise return the root object, as usual

    in_file must be a seekable binary stream.  Data and UID objects are
    returned as raw byte strings, sets as set, arrays as list and dicts
    as dict.

    Original patch diffed here:  https://bugs.python.org/issue14455
    """
    # The last 32 bytes hold the trailer describing the offset table
    in_file.seek(-32, os.SEEK_END)
    trailer = in_file.read(32)
    if len(trailer) != 32:
        raise InvalidFileException()
    offset_size, ref_size, num_objects, top_object, offset_table_offset = struct.unpack('>6xBB4xL4xL4xL', trailer)
    size_codes = {1: 'B', 2: 'H', 4: 'L', 8: 'Q', }
    int_format = {0: (1, '>B'), 1: (2, '>H'), 2: (4, '>L'), 3: (8, '>Q'), }
    ref_format = size_codes[ref_size]
    in_file.seek(offset_table_offset)
    object_offsets = struct.unpack('>' + size_codes[offset_size] * num_objects, in_file.read(offset_size * num_objects))

    def getSize(token_l):
        """ return the size of the next object."""
        if token_l == 0xF:
            # Size did not fit in the marker nibble: an int object follows
            m = ord(in_file.read(1)) & 0x3
            s, f = int_format[m]
            return struct.unpack(f, in_file.read(s))[0]
        return token_l

    def readRefs(count):
        """ read count object references from the current position """
        return struct.unpack('>' + ref_format * count, in_file.read(count * ref_size))

    def readNextObject(offset):
        """ read the object at offset. May recursively read sub-objects (content of an array/dict/set) """
        in_file.seek(offset)
        # Work with the marker as an int so comparisons behave the same on
        # Python 2 (str) and Python 3 (bytes)
        token = ord(in_file.read(1))
        token_h, token_l = token & 0xF0, token & 0x0F #high and low parts
        if token == 0x00:
            return None
        elif token == 0x08:
            return False
        elif token == 0x09:
            return True
        elif token == 0x0f:
            return ''
        elif token_h == 0x10: #int
            # An int object occupies 2**token_l bytes, big-endian
            size = 1 << token_l
            result = 0
            for byte in bytearray(in_file.read(size)):
                result = (result << 8) + byte
            # Ints of 8 bytes or more are stored as signed two's complement
            if token_l >= 3 and result >= 1 << (size * 8 - 1):
                result -= 1 << (size * 8)
            return result
        elif token_h == 0x20: #real
            if token_l == 2:
                return struct.unpack('>f', in_file.read(4))[0]
            elif token_l == 3:
                return struct.unpack('>d', in_file.read(8))[0]
        elif token_h == 0x30: #date
            f = struct.unpack('>d', in_file.read(8))[0]
            return datetime.datetime.utcfromtimestamp(f + MAC_OS_X_TIME_OFFSET)
        elif token_h == 0x40: #data
            s = getSize(token_l)
            return in_file.read(s)
        elif token_h == 0x50: #ascii string
            s = getSize(token_l)
            text = in_file.read(s)
            # Python 2 reads already yield str; Python 3 yields bytes
            return text if isinstance(text, str) else text.decode('ascii')
        elif token_h == 0x60: #unicode string
            s = getSize(token_l)
            return in_file.read(s * 2).decode('utf-16be')
        elif token_h == 0x80: #uid
            return in_file.read(token_l + 1)
        elif token_h == 0xA0: #array
            # All refs are read before recursing, since recursion seeks
            obj_refs = readRefs(getSize(token_l))
            return [readNextObject(object_offsets[x]) for x in obj_refs]
        elif token_h == 0xC0: #set
            obj_refs = readRefs(getSize(token_l))
            return set(readNextObject(object_offsets[x]) for x in obj_refs)
        elif token_h == 0xD0: #dict
            result = {}
            s = getSize(token_l)
            # Keys come first, followed by the same number of value refs
            key_refs = readRefs(s)
            obj_refs = readRefs(s)
            for k, o in zip(key_refs, obj_refs):
                key = readNextObject(object_offsets[k])
                obj = readNextObject(object_offsets[o])
                result[key] = obj
            return result
        raise InvalidFileException()
    return readNextObject(object_offsets[top_object])

class Utils:
    """Small helpers for the command-line script: path cleanup and screen clearing."""

    def __init__(self, name = "Python Script"):
        """Store the script name and load the optional colors.json.

        colors.json is looked up next to this source file; when it is
        absent, colors_dict is an empty dict.
        """
        self.name = name
        # Init our colors before we need to print anything.  The file is
        # addressed by absolute path so the working directory is untouched.
        colors_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "colors.json")
        if os.path.exists(colors_path):
            with open(colors_path) as colors_file:
                self.colors_dict = json.load(colors_file)
        else:
            self.colors_dict = {}

    def check_path(self, path):
        """Resolve a user-supplied (possibly pasted/dragged) path.

        Tries, in order, stripping matching surrounding quotes, expanding a
        leading tilde, trimming leading/trailing spaces and tabs, and
        removing backslash escapes.  Returns the absolute path of the first
        candidate that exists, or None when nothing matches.
        """
        padding = (" ", "\t")
        # Loop until we either get a working path - or no changes
        for _ in range(100):
            if not len(path):
                # We uh.. stripped out everything - bail
                return None
            if os.path.exists(path):
                # Exists!
                return os.path.abspath(path)
            # Check quotes first
            if (path[0] == '"' and path[-1] == '"') or (path[0] == "'" and path[-1] == "'"):
                path = path[1:-1]
                continue
            # Check for tilde
            if path[0] == "~":
                test_path = os.path.expanduser(path)
                if test_path != path:
                    # We got a change
                    path = test_path
                    continue
            # If we have no spaces to trim - bail
            if path[0] not in padding and path[-1] not in padding:
                return None
            # Here we try stripping spaces/tabs
            test_path = path
            for _ in range(100):
                t_path = test_path
                while len(t_path):
                    if os.path.exists(t_path):
                        return os.path.abspath(t_path)
                    if t_path[-1] in padding:
                        t_path = t_path[:-1]
                        continue
                    break
                # Guard against an all-whitespace path running out of characters
                if test_path and test_path[0] in padding:
                    test_path = test_path[1:]
                    continue
                break
            # Escapes!
            test_path = "\\".join([x.replace("\\", "") for x in path.split("\\\\")])
            if test_path != path and path[0] not in padding:
                path = test_path
                continue
            if path[0] in padding:
                path = path[1:]
        return None

    def cls(self):
        """Clear the terminal using the platform's shell command."""
        os.system('cls' if os.name=='nt' else 'clear')

def show_help():
    """Print the k2p usage text: verbs, options and an example invocation."""
    sections = (
        "k2p - KextsToPatch Script\nby CorpNewt\n\n",
        "Usage:  k2p /path/to/plist [verb] [options]\n\n",
        "Available verbs:\n\n",
        "            Add  -  Adds the passed info to the passed plist - whether it exists or not\n",
        "         Update  -  Updates an existing patch if exists, or adds it\n",
        "         Remove  -  Removes all instances of the passed info from the passed plist\n",
        "         Enable  -  Enables the passed patch if it exists\n",
        "        Disable  -  Disables the passed patch if it exists\n",
        "       Disabled  -  Returns whether the passed patch is disabled in the plist\n",
        "            Has  -  Returns whether the plist contains the passed patch\n",
        "    Query [arg]  -  Return the value of the passed arg\n",
        "           List  -  Lists all the patches - can only be used with -Type\n\n",
        "Available options:\n\n",
        "          -Type  -  The type of data passed - hex or base64 (default is hex)\n",
        "          -Find  -  The data to find (Required)\n",
        "       -Replace  -  The data to replace with (Required)\n",
        "          -Name  -  The kext to patch (Required)\n",
        "       -Comment  -  A short description of the patch\n",
        "      -Disabled  -  True/False - sets whether the passed patch is disabled\n",
        "       -MatchOS  -  10.xx.x format - sets the required OS for the patch\n",
        "-InfoPlistPatch  -  True/False - sets whether the passed patch is an Info.plist patch\n",
        "        -Silent  -  If this option is passed, output is suppressed where possible\n",
        "       -Retcode  -  Returns any true/false/error output as 0, 1, or 2 respectively\n\n",
        "Example:\n\n",
        "k2p ~/Desktop/config.plist add -find 0x00000000 -replace 0x00000001 -name AppleHDA\n\n",
    )
    print("".join(sections))

def parse_args(arg_list):
    """Turn an argv-style list into the argument dict used by process().

    arg_list[1] is the plist path and arg_list[2] the verb; the query verb
    additionally consumes arg_list[3] as its term.  The remaining items are
    options, matched case-insensitively against the module-level option
    lists.  Prints a message and exits with status 1 on any problem.
    """
    args = list(arg_list)
    # Check the path
    path = u.check_path(args[1])
    if not path:
        print("Could not locate plist:\n\n\"{}\"".format(args[1]))
        exit(1)
    # Got a path - check the verb
    verb = args[2].lower()
    if verb not in verbs:
        print("Unknown verb:\n\n\"{}\"".format(args[2]))
        exit(1)
    arg_dict = {"Path": path, "Verb": verb}
    # Drop the script name, path and verb (and the term for query)
    if verb == "query":
        arg_dict["Term"] = args[3]
        remaining = args[4:]
    else:
        remaining = args[3:]
    all_opts = r_opts + opts + nv_opts
    var_opts = r_opts + opts
    while remaining:
        arg = remaining.pop(0)
        # Map the passed spelling onto the canonical option name
        argu = next((x for x in all_opts if x.lower() == arg.lower()), None)
        if argu is None:
            print("Unknown argument:\n\n\"{}\"".format(arg))
            exit(1)
        val = True
        if argu in var_opts:
            # Needs a variable
            if not remaining:
                print("Missing value for argument:\n\n\"{}\"".format(arg))
                exit(1)
            val = remaining.pop(0)
        arg_dict[argu.replace("-", "")] = val
    # Verify the required args are passed
    required = [y.replace("-", "") for y in r_opts]
    if arg_dict["Verb"] != "list" and not all([arg_dict.get(x, None) for x in required]):
        print("Missing required arguments.")
        exit(1)
    return arg_dict

def _check_hex(hex_string):
    # Remove 0x/0X
    hex_string = hex_string.replace("0x", "").replace("0X", "")
    hex_string = re.sub(r'[^0-9A-Fa-f]+', '', hex_string)
    return hex_string

def hextobytes(hex_string):
    """Convert a loosely formatted hex string into the bytes it encodes.

    The string is cleaned with _check_hex() before being decoded with
    binascii.unhexlify().
    """
    digits = _check_hex(hex_string)
    return binascii.unhexlify(digits.encode("utf-8"))

def get_entry(args, k2p):
    """Return the first patch in *k2p* matching args' Find, Replace and Name.

    Entries that are not dicts, or that lack one of the three keys, are
    skipped instead of raising, so one malformed patch in the plist does
    not break lookups.  Returns None when nothing matches.
    """
    wanted = [(key, args[key]) for key in ("Find", "Replace", "Name")]
    for patch in k2p:
        if not isinstance(patch, dict):
            continue
        if all(key in patch and patch[key] == value for key, value in wanted):
            return patch
    return None

def rprint(text, code, args, to_exit=False):
    """Print a result and optionally terminate the script.

    When args carries a truthy "Retcode" entry the numeric *code* is
    printed instead of *text*.  With to_exit set, the process exits with
    *code* as its status.
    """
    print(code if args.get("Retcode", False) else text)
    if to_exit:
        # sys.exit is always available; the bare exit() helper comes from
        # the site module and is missing under "python -S" or when frozen.
        sys.exit(code)

def _data_payload(value):
    """Return the raw bytes of a plist data value, or None if *value* is not data."""
    legacy_data = getattr(plistlib, "Data", None)  # class removed in Python 3.9
    if legacy_data is not None and isinstance(value, legacy_data):
        return value.data
    # On Python 2 bytes is str, so plain strings must not be treated as data
    if sys.version_info >= (3, 0) and isinstance(value, bytes):
        return value
    return None


def _wrap_data(raw):
    """Wrap raw bytes so the plist writer emits them as a data element."""
    if sys.version_info >= (3, 0):
        # Python 3 plistlib serializes bytes as <data> directly
        return raw
    return plistlib.Data(raw)


def process(args):
    """Run the verb described by *args* against the plist at args["Path"].

    *args* is the dict built by parse_args().  The plist is loaded, the
    KernelAndKextPatches -> KextsToPatch array is created when missing or
    of the wrong type, and the verb is applied.  The read-only verbs
    (list, has, query, disabled) report through rprint() and exit; the
    others write the plist back and exit with status 0.  Errors are
    reported through rprint() with status 1 (not found) or 2 (bad input).
    """
    # First let's load the file
    try:
        with open(args["Path"], "rb") as f:
            plist_data = load(f)
    except Exception as e:
        rprint("Plist is malformed:\n\n" + str(e), 2, args, True)
    if not isinstance(plist_data, dict):
        rprint("Plist is malformed:\n\nThe root object is not a dictionary.", 2, args, True)
    data_type = args.get("Type", "hex")
    use_hex = data_type.lower() == "hex"
    # Check for KernelAndKextPatches -> KextsToPatch
    if not isinstance(plist_data.get("KernelAndKextPatches", None), dict):
        # Either it doesn't exist - or it's not the right type
        plist_data["KernelAndKextPatches"] = {"KextsToPatch":[]}
    if not isinstance(plist_data["KernelAndKextPatches"].get("KextsToPatch", None), list):
        plist_data["KernelAndKextPatches"]["KextsToPatch"] = []
    k2p = plist_data["KernelAndKextPatches"]["KextsToPatch"]
    # Check if we list first - then handle the other verbs
    verb = args.get("Verb", "").lower()
    if verb == "list":
        # List all the entries - only worried about type here for data
        k2p_text = "Listing {} KextsToPatch {}:\n\n".format(len(k2p), "entry" if len(k2p)==1 else "entries")
        for patch in k2p:
            # Right-align the keys; an empty patch has nothing to pad
            pad = max([len(key) for key in patch] or [0])
            for key in sorted(patch, key=lambda x: x.lower()):
                shown = patch[key]
                raw = _data_payload(shown)
                if raw is not None:
                    if use_hex:
                        shown = binascii.hexlify(raw).decode("utf-8")
                    else:
                        shown = base64.b64encode(raw).decode("utf-8")
                k2p_text+="{}{} : {}\n".format(" "*(pad-len(key)), key, shown)
            if len(patch):
                k2p_text+="\n"
        rprint(k2p_text, 0, args, True)
    # Every other verb needs the Find/Replace payloads as bytes
    if use_hex:
        try:
            args["Find"] = hextobytes(args["Find"])
            args["Replace"] = hextobytes(args["Replace"])
        except (TypeError, ValueError, binascii.Error):
            rprint("Malformed hex - aborting.", 2, args, True)
    else:
        try:
            args["Find"] = base64.b64decode(args["Find"])
            args["Replace"] = base64.b64decode(args["Replace"])
        except (TypeError, ValueError, binascii.Error):
            rprint("Malformed base64 - aborting.", 2, args, True)
    # Break out and check the verbs
    entry = get_entry(args, k2p)
    if verb == "has":
        rprint("True" if entry else "False", 0 if entry else 1, args, True)
    elif verb == "query":
        if not entry:
            rprint("Entry doesn't exist.", 1, args, True)
        if not args.get("Term",None) in entry:
            rprint("Key not found:\n\n\"{}\"".format(args["Term"]), 1, args, True)
        rprint("{}\n{}".format(type(entry[args["Term"]]),entry[args["Term"]]), 0, args, True)
    elif verb == "disabled":
        val = 1 if not entry else (0 if entry.get("Disabled",False) else 1)
        rprint("False" if val else "True", val, args, True)
    elif verb in ["enable","disable"]:
        if not entry:
            # Nothing to toggle - report it rather than crash on None
            rprint("Entry doesn't exist.", 1, args, True)
        entry["Disabled"] = verb == "disable"
    elif verb == "remove":
        if not entry:
            rprint("Entry doesn't exist.", 1, args, True)
        placeholder = entry
        count = 0
        # Remove every matching instance, not just the first
        while entry:
            count += 1
            k2p.remove(entry)
            entry = get_entry(args,k2p)
        if not args.get("Silent",False):
            rprint("Removed {} {} of:\n\n{}".format(count, "instance" if count == 1 else "instances", "\n".join(dumps(placeholder).split("\n")[3:-2])), 0, args)
    elif verb in ["add","update"]:
        if entry and verb == "update":
            # Remove first
            k2p.remove(entry)
        new = {}
        # Copy the patch fields; script-control options never belong in the plist
        for key in args:
            if key in ["Type","Path","Verb","Silent","Retcode","Term"]:
                continue
            value = args[key]
            if key in ["Disabled","InfoPlistPatch"]:
                # Documented as True/False - store a real boolean
                value = str(value).lower() in ["y", "yes", "true"]
            new[key] = value
        # Clean up the data sections
        new["Find"] = _wrap_data(new["Find"])
        new["Replace"] = _wrap_data(new["Replace"])
        k2p.append(new)
        if not args.get("Silent", False):
            rprint("Added:\n\n{}".format("\n".join(dumps(new).split("\n")[3:-2])), 0, args)
    writePlist(plist_data, args["Path"])
    sys.exit(0)

# Global vars
# Verbs accepted as the second command-line argument (compared in lowercase)
verbs   = ["add","update","remove","enable","disable","disabled","has","query","list"]
# Options that must be present for every verb except "list"; each takes a value
r_opts  = ["-Find","-Replace","-Name"]
# Optional options that take a value
opts    = ["-Type","-Comment","-Disabled","-MatchOS","-InfoPlistPatch"]
nv_opts = ["-Silent","-Retcode"] # Options that don't require a second var

if __name__ == "__main__":
    # parse_args() reads this module-level helper
    u = Utils()
    # We need the script path, plist path and verb - plus at least three
    # more items unless the verb is "list"
    argc = len(sys.argv)
    enough_args = argc >= 3 and (argc >= 6 or sys.argv[2].lower() == "list")
    if not enough_args:
        # Show help and exit
        show_help()
        exit(0)
    args = parse_args(sys.argv)
    process(args)